package com.hcloud.common.crud.service.impl;

import com.hcloud.common.core.exception.IdMissingException;
import com.hcloud.common.core.exception.ServiceException;
import com.hcloud.common.crud.entity.BaseEntity;
import com.hcloud.common.crud.repository.BaseRepository;
import com.hcloud.common.crud.service.BaseDataService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Example;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.domain.Specification;

import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import javax.transaction.Transactional;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Generic JPA-backed implementation of {@link BaseDataService}: basic CRUD operations plus
 * reflection-driven "query bean" searches that are translated into JPA {@link Specification}s.
 *
 * @author hepangui
 * @Date 2018/10/31
 */
@Transactional
public abstract class BaseDataServiceImpl<Entity extends BaseEntity, Repository extends BaseRepository<Entity>> implements BaseDataService<Entity> {

    private static final Logger LOGGER = Logger.getLogger(BaseDataServiceImpl.class.getName());

    @Autowired
    protected Repository baseRepository;

    /**
     * Adds an entity. The entity should not have an id yet; its createTime is set to now.
     *
     * @param entity the entity to persist
     * @return the saved entity as returned by the repository
     */
    @Override
    public Entity add(Entity entity) {
        entity.setCreateTime(new Date());
        return this.baseRepository.save(entity);
    }

    /**
     * Saves a list of entities in one call and returns the saved list.
     * Unlike {@link #add}, this does not set createTime on the entities.
     *
     * @param entities the entities to save
     * @return the saved entities as returned by the repository
     */
    @Override
    public List<Entity> saveAll(List<Entity> entities) {
        return this.baseRepository.saveAll(entities);
    }

    /**
     * Updates an entity and sets its updateTime to now.
     *
     * @param entity the entity to update; its id must be neither {@code null} nor empty
     * @return the saved entity as returned by the repository
     * @throws IdMissingException if the entity has no id
     */
    @Override
    public Entity update(Entity entity) {
        if (entity.getId() == null || "".equals(entity.getId())) {
            throw new IdMissingException();
        }
        entity.setUpdateTime(new Date());
        return this.baseRepository.save(entity);
    }

    /**
     * Returns the entity with the given id via {@code JpaRepository#getOne}, which hands out a
     * reference that may be served from the first-level cache (persistence context).
     *
     * <p>NOTE(review): per the Spring Data documentation, {@code getOne} is very likely to return an
     * instance even when no row exists, and to throw {@code EntityNotFoundException} on first access.
     * Do not null-check the result to test existence; use {@link #exists(String)} instead.
     *
     * @param id the entity id
     * @return a reference to the entity
     */
    @Override
    public Entity get(String id) {
        return this.baseRepository.getOne(id);
    }

    /**
     * Deletes an entity; the entity must have an id.
     *
     * @param entity the entity to delete
     */
    @Override
    public void delete(Entity entity) {
        this.baseRepository.delete(entity);
    }

    /**
     * Deletes the entity with the given id.
     *
     * @param id the id of the entity to delete
     */
    @Override
    public void deleteById(String id) {
        this.baseRepository.deleteById(id);
    }

    /**
     * Returns whether an entity with the given id exists.
     *
     * <p>This queries the repository directly rather than going through {@link #get(String)}.
     * {@code get} returns a lazy reference that is generally non-null even for a missing id, so it
     * cannot answer this question.
     *
     * @param id the id to look up; {@code null} yields {@code false}
     * @return {@code true} if a matching entity exists
     */
    @Override
    public boolean exists(String id) {
        return id != null && this.baseRepository.existsById(id);
    }

    /**
     * Queries by the given example, without paging; example conditions are equality matches.
     *
     * @param example the probe/matcher to query with
     * @return all matching entities
     */
    @Override
    public List<Entity> find(Example<Entity> example) {
        return this.baseRepository.findAll(example);
    }

    /**
     * Paged query without any conditions.
     *
     * @param pageable the page request
     * @return the requested page
     */
    @Override
    public Page<Entity> find(Pageable pageable) {
        return this.baseRepository.findAll(pageable);
    }

    /**
     * Queries entities matching the non-empty properties of the given bean, ordered by {@code sort}.
     * No paging is applied; all matches are returned.
     *
     * @param entity the query bean; {@code null} yields an empty list
     * @param sort   the ordering to apply
     * @return the matching entities
     */
    @Override
    public List<Entity> findByCondition(Entity entity, Sort sort) {
        if (entity == null) {
            return new ArrayList<>();
        }
        return this.baseRepository.findAll(getEntitySpecification(entity), sort);
    }

    /**
     * Conditional query driven by the field names declared on the bean. A field's name postfix
     * (the {@code POSTFIX_*} constants, see {@code convertPredicate}) selects the comparison
     * operator.
     *
     * @param entity in practice a query bean extending the entity class; {@code null} yields an
     *               empty list
     * @return the matching entities
     */
    @Override
    public List<Entity> findByCondition(Entity entity) {
        if (entity == null) {
            return new ArrayList<>();
        }
        return this.baseRepository.findAll(getEntitySpecification(entity));
    }

    /**
     * Conditional query with paging.
     *
     * @param pageable the page request
     * @param entity   the query bean; {@code null} falls back to an unconditioned paged query
     * @return the requested page of matching entities
     */
    @Override
    public Page<Entity> findByCondition(Pageable pageable, Entity entity) {
        if (entity == null) {
            return this.find(pageable);
        }
        return this.baseRepository.findAll(getEntitySpecification(entity), pageable);
    }

    /**
     * Builds a specification that ANDs one predicate per populated field of {@code entity}.
     * Each field is read through its public {@code getXxx()} getter. Fields whose value is
     * {@code null} or an empty string are ignored.
     *
     * <p>The translation is best-effort. A field without a public getter, or one that does not map
     * to an entity attribute, is logged and skipped rather than failing the whole query.
     */
    private Specification<Entity> getEntitySpecification(Entity entity) {
        List<String> fieldList = this.getFieldList(entity.getClass());
        return (root, query, criteriaBuilder) -> {
            List<Predicate> predicates = new ArrayList<>();
            for (String fieldName : fieldList) {
                try {
                    Method getter = entity.getClass().getMethod("get" + getMethodName(fieldName));
                    Object value = getter.invoke(entity);
                    // Unset criteria do not constrain the query.
                    if (value == null || "".equals(value)) {
                        continue;
                    }
                    predicates.add(this.convertPredicate(root, criteriaBuilder, fieldName, value));
                } catch (Exception e) {
                    LOGGER.log(Level.WARNING, "Skipping query condition for field '" + fieldName + "'", e);
                }
            }
            return criteriaBuilder.and(predicates.toArray(new Predicate[0]));
        };
    }

    /**
     * Translates one query-bean property into a JPA predicate.
     *
     * <p>The property name's postfix selects the operator and is stripped from the end of the name
     * to obtain the entity attribute. Postfixes are tested in the order below, and the first match
     * wins:
     * <ul>
     *   <li>LIKEFONT: {@code like value%}</li>
     *   <li>LIKEBACK: {@code like %value}</li>
     *   <li>LIKE: {@code like %value%}</li>
     *   <li>GT, LT, GTEQ, LTEQ: comparisons; the value is cast to {@link Comparable}</li>
     *   <li>NOTEQUAL: {@code <>}</li>
     *   <li>ISNULL, ISNOTNULL: the value itself is ignored</li>
     * </ul>
     * A name with no recognised postfix becomes an equality test on the attribute of that name.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Predicate convertPredicate(Root<Entity> root, CriteriaBuilder criteriaBuilder, String fieldName, Object value) {
        if (fieldName.endsWith(POSTFIX_LIKEFONT)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_LIKEFONT));
            return criteriaBuilder.like(path, value + "%");
        }
        if (fieldName.endsWith(POSTFIX_LIKEBACK)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_LIKEBACK));
            return criteriaBuilder.like(path, "%" + value);
        }
        if (fieldName.endsWith(POSTFIX_LIKE)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_LIKE));
            return criteriaBuilder.like(path, "%" + value + "%");
        }
        if (fieldName.endsWith(POSTFIX_GT)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_GT));
            return criteriaBuilder.greaterThan(path, (Comparable) value);
        }
        if (fieldName.endsWith(POSTFIX_LT)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_LT));
            return criteriaBuilder.lessThan(path, (Comparable) value);
        }
        if (fieldName.endsWith(POSTFIX_GTEQ)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_GTEQ));
            return criteriaBuilder.greaterThanOrEqualTo(path, (Comparable) value);
        }
        if (fieldName.endsWith(POSTFIX_LTEQ)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_LTEQ));
            return criteriaBuilder.lessThanOrEqualTo(path, (Comparable) value);
        }
        if (fieldName.endsWith(POSTFIX_NOTEQUAL)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_NOTEQUAL));
            return criteriaBuilder.notEqual(path, value);
        }
        if (fieldName.endsWith(POSTFIX_ISNULL)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_ISNULL));
            return criteriaBuilder.isNull(path);
        }
        if (fieldName.endsWith(POSTFIX_ISNOTNULL)) {
            Path path = root.get(stripPostfix(fieldName, POSTFIX_ISNOTNULL));
            return criteriaBuilder.isNotNull(path);
        }
        Path path = root.get(fieldName);
        return criteriaBuilder.equal(path, value);
    }

    /**
     * Removes {@code postfix} from the end of {@code fieldName}. Callers have already checked
     * {@code fieldName.endsWith(postfix)}. Occurrences elsewhere in the name are kept, unlike
     * {@link String#replace}.
     */
    private static String stripPostfix(String fieldName, String postfix) {
        return fieldName.substring(0, fieldName.length() - postfix.length());
    }

    /**
     * Capitalises the first character of a field name to form the tail of its getter name, for
     * example {@code "name"} becomes {@code "Name"}.
     *
     * @param fieldName the field name; an empty name is returned unchanged
     * @return the name with its first character upper-cased
     * @throws Exception declared for backward compatibility; not thrown by this implementation
     */
    public String getMethodName(String fieldName) throws Exception {
        if (fieldName.isEmpty()) {
            return fieldName;
        }
        // Character.toUpperCase is locale-independent and leaves already-upper-case and
        // non-letter characters intact.
        return Character.toUpperCase(fieldName.charAt(0)) + fieldName.substring(1);
    }

    /**
     * Collects the names of the instance fields declared on {@code clazz} and its superclasses,
     * stopping before {@link Object}. Subclass fields are listed first. Static, transient and
     * synthetic fields are skipped.
     *
     * @param clazz the class to inspect; {@code null} yields an empty list
     * @return the field names, never {@code null}
     */
    private List<String> getFieldList(Class<?> clazz) {
        List<String> fieldList = new LinkedList<>();
        // getSuperclass() returns null for Object, interfaces and primitives, so also stop on null.
        for (Class<?> current = clazz; current != null && current != Object.class; current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                int modifiers = field.getModifiers();
                if (Modifier.isStatic(modifiers) || Modifier.isTransient(modifiers) || field.isSynthetic()) {
                    continue;
                }
                fieldList.add(field.getName());
            }
        }
        return fieldList;
    }
}
